[ROCm] Fix biased wgrad with fp32 gradient accumulation #634
Open
XinyuJiangCMU wants to merge 5 commits into
On ROCm, hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad GEMM that also fuses the bias-gradient (BGRADB) epilogue: the heuristic returns zero algorithms and the GEMM raises "Unable to find any suitable algorithms". This hits any LayerNormLinear with bias (e.g. Qwen2.5 QKV with add-qkv-bias) when training with fp32 gradient accumulation (--accumulate-allreduce-grads-in-fp32). When wgrad is accumulated into an fp32 main_grad on ROCm, skip the fused dbias and reduce grad_bias separately (grad_output.sum over tokens in fp32, cast to bias dtype) -- mathematically identical to the BGRADB epilogue. CUDA and all other paths are unchanged. Co-Authored-By: Jessica Jiang <jessicajiang324@gmail.com> Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Move the BGRADB-unfuse workaround from the per-module LayerNormLinear backward up to general_gemm, the single chokepoint every wgrad path funnels through. This covers Linear, LayerNormLinear, LayerNormMLP and the delayed-wgrad store in one place, and fixes the delayed-wgrad path that the per-module version dropped the bias gradient on. CUDA, the forward bias-add path and fp8/fp4 are untouched. Co-Authored-By: Jessica Jiang <jessicajiang324@gmail.com> Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover the non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the ROCm numerics test that was skipped for this case. Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com> Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the Linear / LayerNormLinear / LayerNormMLP wgrad numerics tests skipped for this case; GroupedLinear routes through general_grouped_gemm and stays skipped. Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com> Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Problem
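On ROCm, hipBLASLt returns no algorithm for a bf16 weight gradient (wgrad) GEMM with fp32 output when the bias gradient is fused into the GEMM as the BGRADB epilogue.
This causes training with --add-qkv-bias and --accumulate-allreduce-grads-in-fp32 to fail in the backward pass with the error "Unable to find any suitable algorithms".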
Fix
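On ROCm, when the wgrad output is fp32, run the weight gradient GEMM without the fused bias gradient and compute the bias gradient separately by summing grad_output over tokens in fp32, then casting to the bias dtype. The gelu epilogue is excluded, since its bias gradient is not a plain grad_output sum.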
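For reviewers, a minimal dependency-free sketch of why the separate reduction matches the fused epilogue: for y = x @ W^T + b, the bias gradient is grad_output summed over the token axis. All function names here are hypothetical and are not Transformer Engine APIs. The sketch uses Python floats, so it does not model the fp32 reduction or the final cast to the bias dtype.

```python
def linear_forward(x, weight, bias):
    """y[i][j] = sum_k x[i][k] * weight[j][k] + bias[j]  (tokens x out_features)."""
    return [
        [sum(xk * wk for xk, wk in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]


def bias_grad_unfused(grad_output):
    """Bias gradient as a plain reduction of grad_output over the token axis."""
    return [sum(col) for col in zip(*grad_output)]


def loss(x, weight, bias, grad_output):
    """Scalar whose gradient with respect to y is exactly grad_output."""
    y = linear_forward(x, weight, bias)
    return sum(g * v
               for g_row, y_row in zip(grad_output, y)
               for g, v in zip(g_row, y_row))


def finite_difference_bias_grad(x, weight, bias, grad_output, eps=0.5):
    """Independent check: perturb each bias element and measure the loss change."""
    base = loss(x, weight, bias, grad_output)
    grads = []
    for j in range(len(bias)):
        bumped = list(bias)
        bumped[j] += eps
        grads.append((loss(x, weight, bumped, grad_output) - base) / eps)
    return grads


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]            # 3 tokens, 2 input features
weight = [[0.5, -1.0], [2.0, 0.25]]                  # 2 output features
bias = [0.1, -0.2]
grad_output = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # upstream gradient dL/dy

grad_bias = bias_grad_unfused(grad_output)
check = finite_difference_bias_grad(x, weight, bias, grad_output)
```

The loss is linear in the bias, so the finite-difference check agrees with the column sum up to floating-point rounding.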
The fix is implemented in general_gemm, covering delayed weight gradient execution and other callers using the same path. CUDA behavior is unchanged.
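The condition under which the split applies, as described in the commit messages, can be sketched as a single predicate. This is an illustration only: WgradGemmRequest and should_unfuse_bias_grad are hypothetical names, and the actual check inside general_gemm may be structured differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WgradGemmRequest:
    """Hypothetical summary of one GEMM call; not a real Transformer Engine type."""
    is_rocm: bool
    out_dtype: str           # e.g. "float32" or "bfloat16"
    wants_bias_grad: bool    # caller asked for the fused dbias (BGRADB) epilogue
    has_gelu: bool           # gelu epilogue: bias grad is not a plain grad_output sum
    is_low_precision: bool   # fp8 / fp4 path, left untouched by the fix


def should_unfuse_bias_grad(req: WgradGemmRequest) -> bool:
    """True when the bias gradient should be reduced outside the GEMM.

    The trigger is the fp32 output dtype, not the accumulate flag, so a
    non-accumulating (e.g. first-microbatch) wgrad is covered as well.
    """
    return (
        req.is_rocm
        and req.wants_bias_grad
        and req.out_dtype == "float32"
        and not req.has_gelu
        and not req.is_low_precision
    )


rocm_fp32 = WgradGemmRequest(True, "float32", True, False, False)
cuda_fp32 = WgradGemmRequest(False, "float32", True, False, False)
rocm_bf16 = WgradGemmRequest(True, "bfloat16", True, False, False)
rocm_gelu = WgradGemmRequest(True, "float32", True, True, False)
```

Keeping the decision in one place mirrors the choice to put the workaround in general_gemm rather than in each module's backward.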
Testing
Verified on MI350X:
The isolated reproduction passes.
Qwen2.5-0.5B GSM8K training passes the previously failing backward step.
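Re-enabled the ROCm wgrad numerics tests for Linear, LayerNormLinear and LayerNormMLP, which were previously skipped by grouped GEMM change 434.
All three use general_gemm. GroupedLinear routes through general_grouped_gemm and stays skipped.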
Result: